{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "name": "ranking_crime.ipynb",
      "provenance": [],
      "collapsed_sections": []
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    }
  },
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "sK-RhlsGxwpd",
        "colab_type": "text"
      },
      "source": [
        "##### Copyright 2020 Google LLC.\n",
        "\n",
        "\n",
        "Licensed under the Apache License, Version 2.0 (the 'License');\n",
        "you may not use this file except in compliance with the License.\n",
        "You may obtain a copy of the License at\n",
        "\n",
        "    https://www.apache.org/licenses/LICENSE-2.0\n",
        "\n",
        "Unless required by applicable law or agreed to in writing, software\n",
        "distributed under the License is distributed on an 'AS IS' BASIS,\n",
        "WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
        "See the License for the specific language governing permissions and\n",
        "limitations under the License."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "fuN9k6Gux-yD",
        "colab_type": "text"
      },
      "source": [
        "This colab contains TensorFlow code for implementing the constrained optimization methods presented in the paper:\n",
        "> Harikrishna Narasimhan, Andrew Cotter, Maya Gupta, Serena Wang, 'Pairwise Fairness for Ranking and Regression', AAAI 2020. [<a href='https://arxiv.org/pdf/1906.05330.pdf'>link</a>]\n",
        "\n",
        "First, let's install and import the relevant libraries."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "JXgLyAJm0UyB",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "import matplotlib.pyplot as plt\n",
        "import numpy as np\n",
        "import pandas as pd\n",
        "import random\n",
        "import sys\n",
        "from sklearn import model_selection\n",
        "import tensorflow as tf"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "iTUyZk_A0XnF",
        "colab_type": "text"
      },
      "source": [
        "We will need the TensorFlow Constrained Optimization (TFCO) library."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "DvhGP5TW0V_J",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "!pip install git+https://github.com/google-research/tensorflow_constrained_optimization"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "iYxlPVVtzWs3",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "import tensorflow_constrained_optimization as tfco"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "gBUr48pLzsqK",
        "colab_type": "text"
      },
      "source": [
        "## Pairwise Ranking Fairness"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "_Wgq7N73rPHV",
        "colab_type": "text"
      },
      "source": [
        "We will be training a linear ranking model $f(x) = w^\\top x$ where $x \\in \\mathbb{R}^d$ is a set of features for a query-document pair. Our goal is to train the model such that it accurately ranks the positive documents in a query above the negative ones.\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "6tZrx9BfOB_Q",
        "colab_type": "text"
      },
      "source": [
        "Specifically, for the ranking model $f$, we denote:\n",
        "- $err(f)$ as the pairwise ranking error for model $f$ over all pairs of positive and negative documents\n",
        "$$\n",
        "err(f) = \\mathbf{E}\\big[\\mathbb{I}\\big(f(x) < f(x')\\big) \\,\\big|\\, y = 1,~ y' = 0\\big]\n",
        "$$\n",
        "\n",
        "\n",
        "- $err_{i,j}(f)$ as the pairwise ranking error over positive-negative document pairs where the pos. document is from group $i$, and the neg. document is from group $j$.\n",
        "\n",
        "$$\n",
        "err_{i, j}(f) = \\mathbf{E}\\big[\\mathbb{I}\\big(f(x) < f(x')\\big) \\,\\big|\\, y = 1,~ y' = 0,~ grp(x) = i, ~grp(x') = j\\big]\n",
        "$$\n",
        "<br>\n",
        "\n",
        "We then wish to solve the following constrained problem:\n",
        "$$min_f\\; err(f)$$\n",
        "$$\\text{   s.t.   } |err_{i,j}(f) - err_{k,\\ell}(f)| \\leq \\epsilon \\;\\;\\; \\forall ((i,j), (k,\\ell)) \\in \\mathcal{G},$$\n",
        "\n",
        "where $\\mathcal{G}$ contains the pairs we are interested in constraining."
      ]
    },
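    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text"
      },
      "source": [
        "To make these definitions concrete, here is a minimal, dependency-free sketch that evaluates $err(f)$ and $err_{i,j}(f)$ by brute force. The toy `docs` list and the helper `pairwise_error` are hypothetical names introduced only for this illustration; they are not used elsewhere in this colab.\n",
        "\n",
        "```python\n",
        "# Hypothetical toy data: one (score, label, group) tuple per document.\n",
        "docs = [\n",
        "    (0.9, 1, 0), (0.4, 1, 1), (0.6, 1, 1),  # positives (label 1)\n",
        "    (0.8, 0, 0), (0.3, 0, 1), (0.1, 0, 0),  # negatives (label 0)\n",
        "]\n",
        "\n",
        "\n",
        "def pairwise_error(docs, pos_group=None, neg_group=None):\n",
        "  # Fraction of (positive, negative) pairs in which the positive document is\n",
        "  # scored strictly below the negative one. A group of None means any group.\n",
        "  pos = [s for s, y, g in docs if y == 1 and pos_group in (None, g)]\n",
        "  neg = [s for s, y, g in docs if y == 0 and neg_group in (None, g)]\n",
        "  mistakes = [p < n for p in pos for n in neg]\n",
        "  return sum(mistakes) / len(mistakes)\n",
        "\n",
        "\n",
        "overall = pairwise_error(docs)  # err(f) over all 9 pairs\n",
        "err_01 = pairwise_error(docs, pos_group=0, neg_group=1)  # err_{0,1}(f)\n",
        "err_10 = pairwise_error(docs, pos_group=1, neg_group=0)  # err_{1,0}(f)\n",
        "print(overall, err_01, err_10)\n",
        "```\n",
        "\n",
        "On this toy data the two cross-group errors differ by 0.5, so a constraint between $err_{0,1}(f)$ and $err_{1,0}(f)$ with, say, $\\epsilon = 0.1$ would be violated."
      ]
    },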
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "qM-PzAuykOmN",
        "colab_type": "text"
      },
      "source": [
        "## Load Communities & Crime Data\n",
        "\n",
        "We will use the benchmark Communities and Crimes dataset from the UCI Machine Learning repository for our illustration. This dataset contains various demographic and racial distribution details (aggregated from census and law enforcement data sources) about different communities in the US, along with the per capita crime rate in each commmunity. As is commonly done in the literature, we will bin the crime rate attribute into two categories: \"low crime\" and \"high crime\", and formulate the task of *ranking* the communities such that the high crime ones are above the low crime ones. We consider communities where the percentage of black population is above the 70-th percentile as the protected group.\n"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "QSUkaGKxBa2M",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "# We will divide the data into 10 batches, and treat each of them as a query.\n",
        "num_queries = 10\n",
        "\n",
        "# List of column names in the dataset.\n",
        "column_names = [\"state\", \"county\", \"community\", \"communityname\", \"fold\", \"population\", \"householdsize\", \"racepctblack\", \"racePctWhite\", \"racePctAsian\", \"racePctHisp\", \"agePct12t21\", \"agePct12t29\", \"agePct16t24\", \"agePct65up\", \"numbUrban\", \"pctUrban\", \"medIncome\", \"pctWWage\", \"pctWFarmSelf\", \"pctWInvInc\", \"pctWSocSec\", \"pctWPubAsst\", \"pctWRetire\", \"medFamInc\", \"perCapInc\", \"whitePerCap\", \"blackPerCap\", \"indianPerCap\", \"AsianPerCap\", \"OtherPerCap\", \"HispPerCap\", \"NumUnderPov\", \"PctPopUnderPov\", \"PctLess9thGrade\", \"PctNotHSGrad\", \"PctBSorMore\", \"PctUnemployed\", \"PctEmploy\", \"PctEmplManu\", \"PctEmplProfServ\", \"PctOccupManu\", \"PctOccupMgmtProf\", \"MalePctDivorce\", \"MalePctNevMarr\", \"FemalePctDiv\", \"TotalPctDiv\", \"PersPerFam\", \"PctFam2Par\", \"PctKids2Par\", \"PctYoungKids2Par\", \"PctTeen2Par\", \"PctWorkMomYoungKids\", \"PctWorkMom\", \"NumIlleg\", \"PctIlleg\", \"NumImmig\", \"PctImmigRecent\", \"PctImmigRec5\", \"PctImmigRec8\", \"PctImmigRec10\", \"PctRecentImmig\", \"PctRecImmig5\", \"PctRecImmig8\", \"PctRecImmig10\", \"PctSpeakEnglOnly\", \"PctNotSpeakEnglWell\", \"PctLargHouseFam\", \"PctLargHouseOccup\", \"PersPerOccupHous\", \"PersPerOwnOccHous\", \"PersPerRentOccHous\", \"PctPersOwnOccup\", \"PctPersDenseHous\", \"PctHousLess3BR\", \"MedNumBR\", \"HousVacant\", \"PctHousOccup\", \"PctHousOwnOcc\", \"PctVacantBoarded\", \"PctVacMore6Mos\", \"MedYrHousBuilt\", \"PctHousNoPhone\", \"PctWOFullPlumb\", \"OwnOccLowQuart\", \"OwnOccMedVal\", \"OwnOccHiQuart\", \"RentLowQ\", \"RentMedian\", \"RentHighQ\", \"MedRent\", \"MedRentPctHousInc\", \"MedOwnCostPctInc\", \"MedOwnCostPctIncNoMtg\", \"NumInShelters\", \"NumStreet\", \"PctForeignBorn\", \"PctBornSameState\", \"PctSameHouse85\", \"PctSameCity85\", \"PctSameState85\", \"LemasSwornFT\", \"LemasSwFTPerPop\", \"LemasSwFTFieldOps\", \"LemasSwFTFieldPerPop\", \"LemasTotalReq\", \"LemasTotReqPerPop\", 
\"PolicReqPerOffic\", \"PolicPerPop\", \"RacialMatchCommPol\", \"PctPolicWhite\", \"PctPolicBlack\", \"PctPolicHisp\", \"PctPolicAsian\", \"PctPolicMinor\", \"OfficAssgnDrugUnits\", \"NumKindsDrugsSeiz\", \"PolicAveOTWorked\", \"LandArea\", \"PopDens\", \"PctUsePubTrans\", \"PolicCars\", \"PolicOperBudg\", \"LemasPctPolicOnPatr\", \"LemasGangUnitDeploy\", \"LemasPctOfficDrugUn\", \"PolicBudgPerPop\", \"ViolentCrimesPerPop\"]\n",
        "\n",
        "dataset_url = \"http://archive.ics.uci.edu/ml/machine-learning-databases/communities/communities.data\"\n",
        "\n",
        "# Read dataset from the UCI web repository and assign column names.\n",
        "data_df = pd.read_csv(dataset_url, sep=\",\", names=column_names,\n",
        "                      na_values=\"?\")\n",
        "\n",
        "# Make sure that there are no missing values in the \"ViolentCrimesPerPop\" column.\n",
        "assert(not data_df[\"ViolentCrimesPerPop\"].isna().any())\n",
        "\n",
        "# Binarize the \"ViolentCrimesPerPop\" column and obtain labels.\n",
        "crime_rate_70_percentile = data_df[\"ViolentCrimesPerPop\"].quantile(q=0.7)\n",
        "labels_df = (data_df[\"ViolentCrimesPerPop\"] >= crime_rate_70_percentile)\n",
        "\n",
        "# Now that we have assigned binary labels, \n",
        "# we drop the \"ViolentCrimesPerPop\" column from the data frame.\n",
        "data_df.drop(columns=\"ViolentCrimesPerPop\", inplace=True)\n",
        "\n",
        "# Group features.\n",
        "race_black_70_percentile = data_df[\"racepctblack\"].quantile(q=0.7)\n",
        "groups_df = (data_df[\"racepctblack\"] >= race_black_70_percentile)\n",
        "\n",
        "# Drop categorical features.\n",
        "data_df.drop(columns=[\"state\", \"county\", \"community\", \"communityname\", \"fold\"],\n",
        "             inplace=True)\n",
        "\n",
        "# Handle missing features.\n",
        "feature_names = data_df.columns\n",
        "for feature_name in feature_names:  \n",
        "    missing_rows = data_df[feature_name].isna()  # Which rows have missing values?\n",
        "    if missing_rows.any():  # Check if at least one row has a missing value.\n",
        "        data_df[feature_name].fillna(0.0, inplace=True)  # Fill NaN with 0.\n",
        "        missing_rows.rename(feature_name + \"_is_missing\", inplace=True)\n",
        "        data_df = data_df.join(missing_rows)  # Append boolean \"is_missing\" feature.\n",
        "\n",
        "labels = labels_df.values.astype(np.float32)\n",
        "groups = groups_df.values.astype(np.float32)\n",
        "features = data_df.values.astype(np.float32)\n",
        "\n",
        "# Set random seed so that the results are reproducible.\n",
        "np.random.seed(123456)\n",
        "\n",
        "# We randomly divide the examples into 'num_queries' queries.\n",
        "queries = np.random.randint(0, num_queries, size=features.shape[0])\n",
        "\n",
        "# Train, vali and test indices.\n",
        "train_indices, test_indices = model_selection.train_test_split(\n",
        "    range(features.shape[0]), test_size=0.4)\n",
        "\n",
        "# Train features, labels and protected groups.\n",
        "train_set = {\n",
        "  'features': features[train_indices, :],\n",
        "  'labels': labels[train_indices],\n",
        "  'groups': groups[train_indices],\n",
        "  'queries': queries[train_indices],\n",
        "  'dimension': features.shape[-1],\n",
        "  'num_queries': num_queries\n",
        "}\n",
        "\n",
        "# Test features, labels and protected groups.\n",
        "test_set = {\n",
        "  'features': features[test_indices, :],\n",
        "  'labels': labels[test_indices],\n",
        "  'groups': groups[test_indices],\n",
        "  'queries': queries[test_indices],\n",
        "  'dimension': features.shape[-1],\n",
        "  'num_queries': num_queries\n",
        "}"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "JxBMPRJA2wvW",
        "colab_type": "text"
      },
      "source": [
        "## Evaluation Metrics"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "q7WpTPkKAga-",
        "colab_type": "text"
      },
      "source": [
        "We will need functions to convert labeled data into paired data."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "u0zUW2wEYMes",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "def pair_pos_neg_docs(data):\n",
        "  # Returns a DataFrame of pairs of positive-negative docs from given DataFrame.\n",
        "  # Separate pos and neg docs.\n",
        "  pos_docs = data[data.label == 1]\n",
        "  if pos_docs.empty:\n",
        "    return\n",
        "  neg_docs = data[data.label == 0]\n",
        "  if neg_docs.empty:\n",
        "    return\n",
        "\n",
        "  # Include a merge key.\n",
        "  pos_docs.insert(0, 'merge_key', 0)\n",
        "  neg_docs.insert(0, 'merge_key', 0)\n",
        "\n",
        "  # Merge docs and drop merge key column.\n",
        "  pairs = pos_docs.merge(neg_docs, on='merge_key', how='outer',\n",
        "                         suffixes=('_pos', '_neg'))\n",
        "  pairs.drop(columns=['merge_key'], inplace=True)\n",
        "  return pairs\n",
        "\n",
        "\n",
        "def convert_labeled_to_paired_data(data_dict, index=None):\n",
        "  # Forms pairs of examples from each batch/query.\n",
        "\n",
        "  # Converts data arrays to pandas DataFrame with required column names and\n",
        "  # makes a call to convert_df_to_pairs and returns a dictionary.\n",
        "  features = data_dict['features']\n",
        "  labels = data_dict['labels']\n",
        "  groups = data_dict['groups']\n",
        "  queries = data_dict['queries']\n",
        "\n",
        "  if index is not None:\n",
        "    data_df = pd.DataFrame(features[queries == index, :])\n",
        "    data_df = data_df.assign(label=pd.DataFrame(labels[queries == index]))\n",
        "    data_df = data_df.assign(group=pd.DataFrame(groups[queries == index]))\n",
        "    data_df = data_df.assign(query_id=pd.DataFrame(queries[queries == index]))\n",
        "  else:\n",
        "    data_df = pd.DataFrame(features)\n",
        "    data_df = data_df.assign(label=pd.DataFrame(labels))\n",
        "    data_df = data_df.assign(group=pd.DataFrame(groups))\n",
        "    data_df = data_df.assign(query_id=pd.DataFrame(queries))\n",
        "\n",
        "  # Forms pairs of positive-negative docs for each query in given DataFrame\n",
        "  # if the DataFrame has a query_id column. Otherise forms pairs from all rows\n",
        "  # of the DataFrame.\n",
        "  data_pairs = data_df.groupby('query_id').apply(pair_pos_neg_docs)\n",
        "\n",
        "  # Create groups ndarray.\n",
        "  pos_groups = data_pairs['group_pos'].values.reshape(-1, 1)\n",
        "  neg_groups = data_pairs['group_neg'].values.reshape(-1, 1)\n",
        "  group_pairs = np.concatenate((pos_groups, neg_groups), axis=1)\n",
        "\n",
        "  # Create queries ndarray.\n",
        "  queries = data_pairs['query_id_pos'].values.reshape(-1,)\n",
        "\n",
        "  # Create features ndarray.\n",
        "  feature_names = data_df.columns\n",
        "  feature_names = feature_names.drop(['query_id', 'label'])\n",
        "  feature_names = feature_names.drop(['group'])\n",
        "\n",
        "  pos_features = data_pairs[[str(s) + '_pos' for s in feature_names]].values\n",
        "  pos_features = pos_features.reshape(-1, 1, len(feature_names))\n",
        "\n",
        "  neg_features = data_pairs[[str(s) + '_neg' for s in feature_names]].values\n",
        "  neg_features = neg_features.reshape(-1, 1, len(feature_names))\n",
        "\n",
        "  features_pairs = np.concatenate((pos_features, neg_features), axis=1)\n",
        "\n",
        "  # Paired data dict.\n",
        "  paired_data = {\n",
        "      'features': features_pairs, \n",
        "      'groups': group_pairs, \n",
        "      'queries': queries,\n",
        "      'dimension': data_dict['dimension'],\n",
        "      'num_queries': data_dict['num_queries']\n",
        "  }\n",
        "\n",
        "  return paired_data"
      ],
      "execution_count": 0,
      "outputs": []
    },
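    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text"
      },
      "source": [
        "The pairing step can also be pictured without pandas. The following is a minimal sketch, assuming each document is a hypothetical `(query_id, label, group)` tuple; `pos_neg_pairs` is an illustrative helper and not part of this colab's pipeline.\n",
        "\n",
        "```python\n",
        "from itertools import product\n",
        "\n",
        "# Hypothetical documents: (query_id, label, group).\n",
        "docs = [(0, 1, 0), (0, 0, 1), (0, 0, 0), (1, 1, 1), (1, 0, 0), (2, 1, 0)]\n",
        "\n",
        "\n",
        "def pos_neg_pairs(docs):\n",
        "  # Pairs every positive document with every negative document that\n",
        "  # belongs to the same query.\n",
        "  pairs = []\n",
        "  for q in sorted({d[0] for d in docs}):\n",
        "    pos = [d for d in docs if d[0] == q and d[1] == 1]\n",
        "    neg = [d for d in docs if d[0] == q and d[1] == 0]\n",
        "    pairs.extend(product(pos, neg))\n",
        "  return pairs\n",
        "\n",
        "\n",
        "pairs = pos_neg_pairs(docs)\n",
        "for pos_doc, neg_doc in pairs:\n",
        "  print(pos_doc, neg_doc)\n",
        "```\n",
        "\n",
        "Queries that lack either a positive or a negative document contribute no pairs, which mirrors the early returns in `pair_pos_neg_docs`."
      ]
    },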
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "H4HV7a7wq7rm",
        "colab_type": "text"
      },
      "source": [
        "We will also need functions to evaluate the pairwise error rates for a linear model."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "K8OQ4ado20p-",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "def get_mask(groups, pos_group, neg_group=None):\n",
        "  # Returns a boolean mask selecting positive-negative document pairs where \n",
        "  # the protected group for  the positive document is pos_group and \n",
        "  # the protected group for the negative document (if specified) is neg_group.\n",
        "  # Repeat group membership positive docs as many times as negative docs.\n",
        "  mask_pos = groups[:, 0] == pos_group\n",
        "  \n",
        "  if neg_group is None:\n",
        "    return mask_pos\n",
        "  else:\n",
        "    mask_neg = groups[:, 1] == neg_group\n",
        "    return mask_pos & mask_neg\n",
        "\n",
        "\n",
        "def error_rate(model, dataset):\n",
        "  # Returns error rate for Keras model on dataset.\n",
        "  d = dataset['dimension']\n",
        "  scores0 = model.predict(dataset['features'][:, 0, 0:d].reshape(-1, d))\n",
        "  scores1 = model.predict(dataset['features'][:, 1, 0:d].reshape(-1, d))\n",
        "  diff = scores0 - scores1  \n",
        "  return np.mean(diff.reshape((-1)) < 0)\n",
        "\n",
        "\n",
        "def group_error_rate(model, dataset, pos_group, neg_group=None):\n",
        "  # Returns error rate for Keras model on data set, considering only document \n",
        "  # pairs where the protected group for the positive document is pos_group, and  \n",
        "  # the protected group for the negative document (if specified) is neg_group.\n",
        "  d = dataset['dimension']\n",
        "  scores0 = model.predict(dataset['features'][:, 0, :].reshape(-1, d))\n",
        "  scores1 = model.predict(dataset['features'][:, 1, :].reshape(-1, d))\n",
        "  mask = get_mask(dataset['groups'], pos_group, neg_group)\n",
        "  diff = scores0 - scores1\n",
        "  diff = diff[mask > 0].reshape((-1))\n",
        "  return np.mean(diff < 0)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "kI8xNJDcpQYP",
        "colab_type": "text"
      },
      "source": [
        "## Create Linear Model\n",
        "\n",
        "\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "lY4hvJAOra6s",
        "colab_type": "text"
      },
      "source": [
        "We then write a function to create the linear ranking model."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "eTQOebAepXSu",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "def create_ranking_model(features, dimension):\n",
        "  # Returns a linear Keras ranking model, and returns a nullary function \n",
        "  # returning predictions on the features.\n",
        "\n",
        "  # Linear ranking model with no hidden layers.\n",
        "  # No bias included as this is a ranking problem.\n",
        "  layers = []\n",
        "  # Input layer takes `dimension` inputs.\n",
        "  layers.append(tf.keras.Input(shape=(dimension,)))\n",
        "  layers.append(tf.keras.layers.Dense(1, use_bias=False)) \n",
        "  ranking_model = tf.keras.Sequential(layers)\n",
        "\n",
        "  # Create a nullary function that returns applies the linear model to the \n",
        "  # features and returns the tensor with the predictions.\n",
        "  def predictions():\n",
        "    scores0 = ranking_model(features()[:, 0, :].reshape(-1, dimension))\n",
        "    scores1 = ranking_model(features()[:, 1, :].reshape(-1, dimension))\n",
        "    return tf.reshape(scores0 - scores1, (-1,))\n",
        "\n",
        "  return ranking_model, predictions"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "WIBvG3Arv7zR",
        "colab_type": "text"
      },
      "source": [
        "## Formulate Optimization Problem"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "SfZd-XPt0A8E",
        "colab_type": "text"
      },
      "source": [
        "We are ready to formulate the constrained optimization problem using the TFCO library. "
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "0AfVknixv9So",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "def group_mask_fn(groups, pos_group, neg_group=None):\n",
        "  # Returns a nullary function returning group mask.\n",
        "  group_mask = lambda: np.reshape(\n",
        "      get_mask(groups(), pos_group, neg_group), (-1))\n",
        "  return group_mask\n",
        "\n",
        "\n",
        "def formulate_problem(\n",
        "    features, groups, dimension, constraint_groups=[], constraint_slack=None):\n",
        "  # Formulates a constrained problem that optimizes the error rate for a linear\n",
        "  # model on the specified dataset, subject to pairwise fairness constraints \n",
        "  # specified by the constraint_groups and the constraint_slack.\n",
        "  # \n",
        "  # Args:\n",
        "  #   features: Nullary function returning features\n",
        "  #   groups: Nullary function returning groups\n",
        "  #   labels: Nullary function returning labels\n",
        "  #   dimension: Input dimension for ranking model\n",
        "  #   constraint_groups: List containing tuples of the form \n",
        "  #     ((pos_group0, neg_group0), (pos_group1, neg_group1)), specifying the \n",
        "  #     group memberships for the document pairs to compare in the constraints.\n",
        "  #   constraint_slack: slackness '\\epsilon' allowed in the constraints.\n",
        "  # Returns:\n",
        "  #   A RateMinimizationProblem object, and a Keras ranking model.\n",
        "\n",
        "  # Set random seed for reproducibility.\n",
        "  random.seed(333333)\n",
        "  np.random.seed(121212)\n",
        "  tf.random.set_seed(212121)\n",
        "\n",
        "  # Create linear ranking model: we get back a Keras model and a nullary  \n",
        "  # function returning predictions on the features.\n",
        "  ranking_model, predictions = create_ranking_model(features, dimension)\n",
        "  \n",
        "  # Context for the optimization objective.\n",
        "  context = tfco.rate_context(predictions)\n",
        "  \n",
        "  # Constraint set.\n",
        "  constraint_set = []\n",
        "  \n",
        "  # Context for the constraints.\n",
        "  for ((pos_group0, neg_group0), (pos_group1, neg_group1)) in constraint_groups:\n",
        "    # Context for group 0.\n",
        "    group_mask0 = group_mask_fn(groups, pos_group0, neg_group0)\n",
        "    context_group0 = context.subset(group_mask0)\n",
        "\n",
        "    # Context for group 1.\n",
        "    group_mask1 = group_mask_fn(groups, pos_group1, neg_group1)\n",
        "    context_group1 = context.subset(group_mask1)\n",
        "\n",
        "    # Add constraints to constraint set.\n",
        "    constraint_set.append(\n",
        "        tfco.negative_prediction_rate(context_group0) <= (\n",
        "            tfco.negative_prediction_rate(context_group1) + constraint_slack))\n",
        "    constraint_set.append(\n",
        "        tfco.negative_prediction_rate(context_group1) <= (\n",
        "            tfco.negative_prediction_rate(context_group0) + constraint_slack))\n",
        "  \n",
        "  # Formulate constrained minimization problem.\n",
        "  problem = tfco.RateMinimizationProblem(\n",
        "      tfco.negative_prediction_rate(context), constraint_set)\n",
        "  \n",
        "  return problem, ranking_model"
      ],
      "execution_count": 0,
      "outputs": []
    },
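    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text"
      },
      "source": [
        "Each entry of `constraint_groups` contributes two one-sided constraints, which together encode $|err_{i,j}(f) - err_{k,\\ell}(f)| \\leq \\epsilon$. A minimal plain-Python sketch of that equivalence follows; `one_sided_violations` and `is_feasible` are hypothetical helpers for illustration and do not call TFCO.\n",
        "\n",
        "```python\n",
        "def one_sided_violations(err_a, err_b, slack):\n",
        "  # Rewrites |err_a - err_b| <= slack as two constraints of the form g <= 0.\n",
        "  return [err_a - err_b - slack, err_b - err_a - slack]\n",
        "\n",
        "\n",
        "def is_feasible(err_a, err_b, slack):\n",
        "  # The absolute-value constraint holds exactly when neither side is violated.\n",
        "  return max(one_sided_violations(err_a, err_b, slack)) <= 0\n",
        "\n",
        "\n",
        "print(is_feasible(0.30, 0.25, slack=0.1))  # gap of 0.05 is within the slack\n",
        "print(is_feasible(0.50, 0.00, slack=0.1))  # gap of 0.50 exceeds the slack\n",
        "```\n",
        "\n",
        "Splitting the absolute value this way keeps every constraint in the form of a single inequality between two rates."
      ]
    },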
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "P1x4yEllRKjH",
        "colab_type": "text"
      },
      "source": [
        "## Train Model"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "16nddoPIrmuj",
        "colab_type": "text"
      },
      "source": [
        "The following function then trains the linear model by solving the above constrained optimization problem. We first provide a training function that performs one gradient update per query. There are three types of pairwise fairness criterion we handle (specified by 'constraint_type'), and assign the (pos_group, neg_group) pairs to compare accordingly."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "Md5pDHyBRN83",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "def train_model(train_set, params):\n",
        "  # Trains the model with stochastic updates (one query per updates).\n",
        "  #\n",
        "  # Args:\n",
        "  #   train_set: Dictionary of \"paired\" training data.\n",
        "  #   params: Dictionary of hyper-paramters for training.\n",
        "  #\n",
        "  # Returns:\n",
        "  #   Trained model, list of objectives, list of group constraint violations.\n",
        "\n",
        "  # Set up problem and model.\n",
        "  if params['constrained']:\n",
        "    # Constrained optimization.\n",
        "    if params['constraint_type'] == 'marginal_equal_opportunity':\n",
        "      constraint_groups = [((0, None), (1, None))]\n",
        "    elif params['constraint_type'] == 'cross_group_equal_opportunity':\n",
        "      constraint_groups = [((0, 1), (1, 0))]\n",
        "    else:\n",
        "      constraint_groups = [((0, 1), (1, 0)), ((0, 0), (1, 1))]\n",
        "  else:\n",
        "    # Unconstrained optimization.\n",
        "    constraint_groups = []\n",
        "\n",
        "  # Dictionary that will hold batch features pairs, group pairs and labels for \n",
        "  # current batch. We include one query per-batch. \n",
        "  paired_batch = {}\n",
        "  batch_index = 0  # Index of current query.\n",
        "\n",
        "  # Data functions.\n",
        "  features = lambda: paired_batch['features']\n",
        "  groups = lambda: paired_batch['groups'] \n",
        "\n",
        "  # Create ranking model and constrained optimization problem.\n",
        "  problem, ranking_model = formulate_problem(\n",
        "      features, groups, train_set['dimension'], constraint_groups, \n",
        "      params['constraint_slack'])\n",
        "  \n",
        "  # Create a loss function for the problem.\n",
        "  lagrangian_loss, update_ops, multipliers_variables = (\n",
        "      tfco.create_lagrangian_loss(problem, dual_scale=params['dual_scale']))\n",
        "\n",
        "  # Create optimizer\n",
        "  optimizer = tf.keras.optimizers.Adagrad(learning_rate=params['learning_rate'])\n",
        "  \n",
        "  # List of trainable variables.\n",
        "  var_list = (\n",
        "      ranking_model.trainable_weights + problem.trainable_variables + \n",
        "      [multipliers_variables])\n",
        "  \n",
        "  # List of objectives, group constraint violations.\n",
        "  # violations, and snapshot of models during course of training.\n",
        "  objectives = []\n",
        "  group_violations = []\n",
        "  models = []\n",
        "\n",
        "  features = train_set['features']\n",
        "  queries = train_set['queries']\n",
        "  groups = train_set['groups']\n",
        "\n",
        "  print()\n",
        "  # Run loops * iterations_per_loop full batch iterations.\n",
        "  for ii in range(params['loops']):\n",
        "    for jj in range(params['iterations_per_loop']):\n",
        "      # Populate paired_batch dict with all pairs for current query. The batch\n",
        "      # index is the same as the current query index.\n",
        "      paired_batch = {\n",
        "          'features': features[queries == batch_index],\n",
        "          'groups': groups[queries == batch_index]\n",
        "      }\n",
        "\n",
        "      # Optimize loss.\n",
        "      update_ops()\n",
        "      optimizer.minimize(lagrangian_loss, var_list=var_list)\n",
        "\n",
        "      # Update batch_index, and cycle back once last query is reached.\n",
        "      batch_index = (batch_index + 1) % train_set['num_queries']\n",
        "    \n",
        "    # Snap shot current model.\n",
        "    model_copy = tf.keras.models.clone_model(ranking_model)\n",
        "    model_copy.set_weights(ranking_model.get_weights())\n",
        "    models.append(model_copy)\n",
        "\n",
        "    # Evaluate metrics for snapshotted model. \n",
        "    error, gerr, group_viol = evaluate_results(\n",
        "        ranking_model, train_set, params)\n",
        "    objectives.append(error)\n",
        "    group_violations.append(\n",
        "        [x - params['constraint_slack'] for x in group_viol])\n",
        "\n",
        "    sys.stdout.write(\n",
        "        '\\r Loop %d: error = %.3f, max constraint violation = %.3f' % \n",
        "        (ii, objectives[-1], max(group_violations[-1])))\n",
        "  print()\n",
        "  \n",
        "  if params['constrained']:\n",
        "    # Find model iterate that trades-off between objective and group violations.\n",
        "    best_index = tfco.find_best_candidate_index(\n",
        "        np.array(objectives), np.array(group_violations), rank_objectives=False)\n",
        "  else:\n",
        "    # Find model iterate that achieves lowest objective.\n",
        "    best_index = np.argmin(objectives)\n",
        "\n",
        "  return models[best_index]"
      ],
      "execution_count": 0,
      "outputs": []
    },
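    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text"
      },
      "source": [
        "`train_model` keeps a snapshot after every loop and then picks one of them. As a rough mental model only, the sketch below selects the lowest-objective snapshot among those whose largest violation is non-positive, and falls back to the least-violating snapshot otherwise. This is a simplified stand-in with a hypothetical name (`pick_snapshot`); it is not the heuristic implemented by `tfco.find_best_candidate_index`, which may choose differently.\n",
        "\n",
        "```python\n",
        "def pick_snapshot(objectives, violations):\n",
        "  # objectives[t] is the error of snapshot t; violations[t] lists its\n",
        "  # constraint violations (non-positive means the constraint is satisfied).\n",
        "  worst = [max(v) for v in violations]\n",
        "  feasible = [t for t, w in enumerate(worst) if w <= 0]\n",
        "  if feasible:\n",
        "    # Lowest objective among the feasible snapshots.\n",
        "    return min(feasible, key=lambda t: objectives[t])\n",
        "  # No feasible snapshot: fall back to the smallest worst-case violation.\n",
        "  return min(range(len(worst)), key=lambda t: worst[t])\n",
        "\n",
        "\n",
        "# Made-up numbers for three snapshots, each with two constraint violations.\n",
        "objectives = [0.20, 0.24, 0.22]\n",
        "violations = [[0.15, -0.35], [-0.02, -0.18], [-0.01, -0.19]]\n",
        "best = pick_snapshot(objectives, violations)\n",
        "print(best)\n",
        "```\n",
        "\n",
        "In this made-up example the first snapshot has the lowest error but violates a constraint, so the sketch prefers a later snapshot that satisfies both."
      ]
    },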
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "WxFJV0tvKvyR",
        "colab_type": "text"
      },
      "source": [
        "## Summarize and Plot Results"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "i7In7Ra7M_S7",
        "colab_type": "text"
      },
      "source": [
        "Having trained a model, we will need functions to summarize the various evaluation metrics."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "CBl5KfEOPApl",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "def evaluate_results(model, test_set, params):\n",
        "  # Returns overall, group error rates, group-level constraint violations.\n",
        "  if params['constraint_type'] == 'marginal_equal_opportunity':\n",
        "    g0_error = group_error_rate(model, test_set, 0)\n",
        "    g1_error = group_error_rate(model, test_set, 1)\n",
        "    group_violations = [g0_error - g1_error, g1_error - g0_error]\n",
        "    return (error_rate(model, test_set), [g0_error, g1_error], \n",
        "            group_violations)\n",
        "  else:\n",
        "    g00_error = group_error_rate(model, test_set, 0, 0)\n",
        "    g01_error = group_error_rate(model, test_set, 0, 1)\n",
        "    g10_error = group_error_rate(model, test_set, 1, 1)\n",
        "    g11_error = group_error_rate(model, test_set, 1, 1)\n",
        "    group_violations_offdiag = [g01_error - g10_error, g10_error - g01_error]\n",
        "    group_violations_diag = [g00_error - g11_error, g11_error - g00_error]\n",
        "\n",
        "    if params['constraint_type'] == 'cross_group_equal_opportunity':\n",
        "      return (error_rate(model, test_set), \n",
        "              [[g00_error, g01_error], [g10_error, g11_error]], \n",
        "              group_violations_offdiag)\n",
        "    else:\n",
        "      return (error_rate(model, test_set), \n",
        "              [[g00_error, g01_error], [g10_error, g11_error]], \n",
        "              group_violations_offdiag + group_violations_diag)\n",
        "    \n",
        "\n",
        "def display_results(\n",
        "    model, test_set, params, method, error_type, show_header=False):\n",
        "  # Prints evaluation results for model on test data.\n",
        "  error, group_error, diffs = evaluate_results(model, test_set, params)\n",
        "\n",
        "  if params['constraint_type'] == 'marginal_equal_opportunity':\n",
        "    if show_header:\n",
        "      print('\\nMethod\\t\\t\\tError\\t\\tOverall\\t\\tGroup 0\\t\\tGroup 1\\t\\tDiff')\n",
        "    print('%s\\t%s\\t\\t%.3f\\t\\t%.3f\\t\\t%.3f\\t\\t%.3f' % (\n",
        "        method, error_type, error, group_error[0], group_error[1], \n",
        "        np.max(diffs)))\n",
        "  elif params['constraint_type'] == 'cross_group_equal_opportunity':\n",
        "    if show_header:\n",
        "      print('\\nMethod\\t\\t\\tError\\t\\tOverall\\t\\tGroup 0/1\\tGroup 1/0\\tDiff')\n",
        "    print('%s\\t%s\\t\\t%.3f\\t\\t%.3f\\t\\t%.3f\\t\\t%.3f' % (\n",
        "        method, error_type, error, group_error[0][1], group_error[1][0], \n",
        "        np.max(diffs)))\n",
        "  else:\n",
        "    if show_header:\n",
        "      print('\\nMethod\\t\\t\\tError\\t\\tOverall\\t\\tGroup 0/1\\tGroup 1/0\\t' +\n",
        "            'Group 0/0\\tGroup 1/1\\tDiff')\n",
        "    print('%s\\t%s\\t\\t%.3f\\t\\t%.3f\\t\\t%.3f\\t\\t%.3f\\t\\t%.3f\\t\\t%.3f' % (\n",
        "        method, error_type, error, group_error[0][1], group_error[1][0], \n",
        "        group_error[0][0], group_error[1][1], np.max(diffs)))"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "PQR0nnORRedG",
        "colab_type": "text"
      },
      "source": [
        "# Experimental Results"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "jTYOW_EOsrWV",
        "colab_type": "text"
      },
      "source": [
        "We now run experiments with two types of pairwise fairness criteria: (1) marginal_equal_opportunity and (2) pairwise equal opportunity. In each case, we compare an unconstrained model trained to optimize the error rate and a constrained model trained with pairwise fairness constraints.\n"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "qMsrOfh7zkbE",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "# Convert train/test set to paired data for later evaluation.\n",
        "paired_train_set = convert_labeled_to_paired_data(train_set)\n",
        "paired_test_set = convert_labeled_to_paired_data(test_set)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "jqxzaPTEwEIn",
        "colab_type": "text"
      },
      "source": [
        "\n",
        "## (1) Marginal Equal Opportunity\n",
        "\n",
        "\n",
        "For a ranking model $f: \\mathbb{R}^d \\rightarrow \\mathbb{R}$, recall:\n",
        "- $err(f)$ as the pairwise ranking error for model $f$ over all pairs of positive and negative documents\n",
        "$$\n",
        "err(f) ~=~ \\mathbf{E}\\big[\\mathbb{I}\\big(f(x) < f(x')\\big) \\,\\big|\\, y = 1,~ y' = 0\\big]\n",
        "$$\n",
        "\n",
        "and we additionally define:\n",
        "\n",
        "- $err_i(f)$ as the row-marginal pairwise error over positive-negative document pairs where the pos. document is from group $i$, and the neg. document is from either groups\n",
        "\n",
        "$$\n",
        "err_i(f) = \\mathbf{E}\\big[\\mathbb{I}\\big(f(x) < f(x')\\big) \\,\\big|\\, y = 1,~ y' = 0,~ grp(x) = i\\big]\n",
        "$$\n",
        "\n",
        "The constrained optimization problem we solve constraints the row-marginal pairwise errors to be similar:\n",
        "\n",
        "$$min_f\\;err(f)$$\n",
        "\n",
        "$$\\text{s.t.   }\\;|err_0(f) - err_1(f)| \\leq 0.05$$\n"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "1JsuylHbjrBX",
        "colab_type": "code",
        "outputId": "cbf0f2a0-5cd8-44af-c1f7-05d2b9f1a72b",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 230
        }
      },
      "source": [
        "# Model hyper-parameters.\n",
        "model_params = {\n",
        "    'loops': 10, \n",
        "    'iterations_per_loop': 250, \n",
        "    'learning_rate': 0.1,\n",
        "    'constraint_type': 'marginal_equal_opportunity', \n",
        "    'constraint_slack': 0.05,\n",
        "    'dual_scale': 0.1}\n",
        "\n",
        "# Unconstrained optimization.\n",
        "model_params['constrained'] = False\n",
        "model_unc  = train_model(paired_train_set, model_params)\n",
        "display_results(model_unc, paired_train_set, model_params, 'Unconstrained     ', \n",
        "                'Train', show_header=True)\n",
        "display_results(model_unc, paired_test_set, model_params,  'Unconstrained     ', \n",
        "                'Test')\n",
        "\n",
        "# Constrained optimization with TFCO.\n",
        "model_params['constrained'] = True\n",
        "model_con  = train_model(paired_train_set, model_params)\n",
        "display_results(model_con, paired_train_set, model_params, 'Constrained     ', \n",
        "                'Train', show_header=True)\n",
        "display_results(model_con, paired_test_set, model_params, 'Constrained     ', \n",
        "                'Test')"
      ],
      "execution_count": 44,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "\n",
            " Loop 9: error = 0.057, max constraint violation = 0.007\n",
            "\n",
            "Method\t\t\tError\t\tOverall\t\tGroup 0\t\tGroup 1\t\tDiff\n",
            "Unconstrained     \tTrain\t\t0.057\t\t0.093\t\t0.036\t\t0.057\n",
            "Unconstrained     \tTest\t\t0.079\t\t0.149\t\t0.043\t\t0.106\n",
            "\n",
            " Loop 9: error = 0.056, max constraint violation = 0.002\n",
            "\n",
            "Method\t\t\tError\t\tOverall\t\tGroup 0\t\tGroup 1\t\tDiff\n",
            "Constrained     \tTrain\t\t0.063\t\t0.093\t\t0.044\t\t0.049\n",
            "Constrained     \tTest\t\t0.078\t\t0.140\t\t0.047\t\t0.093\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "CorY2URkop1Y"
      },
      "source": [
        "## (2) Pairwise Equal Opportunity\n",
        "\n",
        "Recall that we denote\n",
        " $err_{i,j}(f)$ as the ranking error over positive-negative document pairs where the pos. document is from group $i$, and the neg. document is from group $j$.\n",
        "$$\n",
        "err_{i, j}(f) ~=~ \\mathbf{E}\\big[\\mathbb{I}\\big(f(x) < f(x')\\big) \\,\\big|\\, y = 1,~ y' = 0,~ grp(x) = i, ~grp(x') = j\\big]\n",
        "$$\n",
        "\n",
        "\n",
        "We first constrain only the cross-group errors, highlighted below.\n",
        "\n",
        "<br>\n",
        "<table border='1' bordercolor='black'>\n",
        "  <tr >\n",
        "     <td bgcolor='white'> </td>\n",
        "     <td bgcolor='white'> </td>\n",
        "     <td bgcolor='white'  colspan=2 align=center><b>Negative</b></td>\n",
        "  </tr>\n",
        "  <tr>\n",
        "    <td bgcolor='white'></td>\n",
        "    <td bgcolor='white'></td>\n",
        "    <td>Group 0</td>\n",
        "    <td>Group 1</td>\n",
        "  </tr>\n",
        "  <tr>\n",
        "    <td bgcolor='white' rowspan=2><b>Positive</b></td>\n",
        "    <td bgcolor='white'>Group 0</td>\n",
        "    <td bgcolor='white'>$err_{0,0}$</td>\n",
        "    <td bgcolor='white'>$\\mathbf{err_{0,1}}$</td>\n",
        "  </tr>\n",
        "  <tr>\n",
        "    <td>Group 1</td>\n",
        "     <td bgcolor='white'>$\\mathbf{err_{1,0}}$</td>\n",
        "      <td bgcolor='white'>$err_{1,1}$</td>\n",
        "  </tr>\n",
        "</table>\n",
        "<br>\n",
        "\n",
        "The optimization problem we solve constraints the cross-group pairwise errors to be similar:\n",
        "\n",
        "$$min_f\\; err(f)$$\n",
        "$$\\text{s.t. }\\;\\; |err_{0,1}(f) - err_{1,0}(f)| \\leq 0.05$$\n"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "vqat8pXHStjw",
        "outputId": "b8d45ebf-9d98-47a3-9ced-4c1acee31307",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 230
        }
      },
      "source": [
        "# Model hyper-parameters.\n",
        "model_params = {\n",
        "    'loops': 10, \n",
        "    'iterations_per_loop': 250, \n",
        "    'learning_rate': 0.1,\n",
        "    'constraint_type': 'cross_group_equal_opportunity', \n",
        "    'constraint_slack': 0.05,\n",
        "    'dual_scale': 0.1}\n",
        "\n",
        "# Unconstrained optimization.\n",
        "model_params['constrained'] = False\n",
        "model_unc  = train_model(paired_train_set, model_params)\n",
        "display_results(model_unc, paired_train_set, model_params, 'Unconstrained     ', \n",
        "                'Train', show_header=True)\n",
        "display_results(model_unc, paired_test_set, model_params,  'Unconstrained     ', \n",
        "                'Test')\n",
        "\n",
        "# Constrained optimization with TFCO.\n",
        "model_params['constrained'] = True\n",
        "model_con  = train_model(paired_train_set, model_params)\n",
        "display_results(model_con, paired_train_set, model_params, 'Constrained     ', \n",
        "                'Train', show_header=True)\n",
        "display_results(model_con, paired_test_set, model_params, 'Constrained     ', \n",
        "                'Test')"
      ],
      "execution_count": 45,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "\n",
            " Loop 9: error = 0.057, max constraint violation = 0.109\n",
            "\n",
            "Method\t\t\tError\t\tOverall\t\tGroup 0/1\tGroup 1/0\tDiff\n",
            "Unconstrained     \tTrain\t\t0.057\t\t0.289\t\t0.130\t\t0.159\n",
            "Unconstrained     \tTest\t\t0.079\t\t0.333\t\t0.135\t\t0.198\n",
            "\n",
            " Loop 9: error = 0.074, max constraint violation = -0.041\n",
            "\n",
            "Method\t\t\tError\t\tOverall\t\tGroup 0/1\tGroup 1/0\tDiff\n",
            "Constrained     \tTrain\t\t0.074\t\t0.117\t\t0.126\t\t0.009\n",
            "Constrained     \tTest\t\t0.105\t\t0.186\t\t0.147\t\t0.039\n"
          ],
          "name": "stdout"
        }
      ]
    }
  ]
}